{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3151a0f1-2f37-4202-aecd-5fee9f9cc909",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the %pip magic (instead of !pip) so zarr is installed into the environment of the running kernel\n",
    "%pip install -q zarr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d59d417-2348-486f-aaab-9d43b5418e2a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from os.path import join\n",
    "\n",
    "import anndata\n",
    "import dask\n",
    "import dask.array as da\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "dask.config.set(scheduler='threads');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a962492-ae17-4612-926f-d5c14ae31246",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def read_X(path):\n",
    "    \"\"\"Read the h5ad file at `path` and return its `X` matrix.\"\"\"\n",
    "    return anndata.read_h5ad(path).X\n",
    "\n",
    "\n",
    "def _read_backed_attr(path, attr):\n",
    "    \"\"\"Open `path` in backed mode, fetch the attribute `attr` ('obs' or 'var') and close the file again.\"\"\"\n",
    "    adata = anndata.read_h5ad(path, backed='r')\n",
    "    try:\n",
    "        return getattr(adata, attr)\n",
    "    finally:\n",
    "        # a backed AnnData keeps its file handle open until it is closed explicitly\n",
    "        adata.file.close()\n",
    "\n",
    "\n",
    "def read_obs(path):\n",
    "    \"\"\"Return the `obs` DataFrame stored in `path` with an additional `tech_sample` column.\n",
    "\n",
    "    `tech_sample` is the string `<dataset_id>_<donor_id>` of each row.\n",
    "    \"\"\"\n",
    "    obs = _read_backed_attr(path, 'obs')\n",
    "    obs['tech_sample'] = obs.dataset_id.astype(str) + '_' + obs.donor_id.astype(str)\n",
    "    return obs\n",
    "\n",
    "\n",
    "def read_var(path):\n",
    "    \"\"\"Return the `var` DataFrame stored in `path`.\"\"\"\n",
    "    return _read_backed_attr(path, 'var')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1152c31d-90b6-42cf-8e53-c5ffe8caafed",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a46eb975-b4c8-4c41-af47-6c8873ed4221",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a67522c8-192d-4657-999f-13cc24ea0470",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "BASE_PATH = '/mnt/dssfs02/cxg_census/h5ad_raw_2023_05_15'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "116db53b-3cb9-4b98-b181-e2334fe638d9",
   "metadata": {},
   "source": [
    "### Convert to zarr + DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd764cbd-52c6-4833-8910-21e5ba14ef98",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# filter for .h5ad files BEFORE sorting: the sort key parses the file name as an integer\n",
    "# and would raise for any other directory entry (e.g. '.ipynb_checkpoints')\n",
    "files = [\n",
    "    join(BASE_PATH, file) for file\n",
    "    in sorted(\n",
    "        (file for file in os.listdir(BASE_PATH) if file.endswith('.h5ad')),\n",
    "        key=lambda x: int(x.split('.')[0])\n",
    "    )\n",
    "]\n",
    "\n",
    "# read obs\n",
    "print('Loading obs...')\n",
    "obs_per_file = [read_obs(file) for file in files]\n",
    "# number of cells in each file -> dask needs the shape of every lazily loaded block up front.\n",
    "# Taking it from the files themselves works for any number of files and any file sizes.\n",
    "split_lens = [len(obs_file) for obs_file in obs_per_file]\n",
    "obs = pd.concat(obs_per_file).reset_index(drop=True)\n",
    "for col in obs.columns:\n",
    "    if obs[col].dtype == object:\n",
    "        # remove_unused_categories returns a new Series -> the result has to be assigned back\n",
    "        obs[col] = obs[col].astype('category').cat.remove_unused_categories()\n",
    "# read var\n",
    "print('Loading var...')\n",
    "var = read_var(files[0])\n",
    "# read X\n",
    "print('Loading X...')\n",
    "X = da.concatenate([\n",
    "    da.from_delayed(dask.delayed(read_X)(file), (split_len, len(var)), dtype='f4')\n",
    "    for file, split_len in zip(files, split_lens)\n",
    "]).rechunk((32768, -1)).persist()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38c08a38-3153-4e9f-8ba6-744670230044",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "X"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d19cc3c8-9e20-437a-8605-d5350757180e",
   "metadata": {},
   "source": [
    "### Create train, val, test split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b43814-ea3e-4680-9cb6-767d5624fbb7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from statistics import mode\n",
    "from scipy.sparse import csr_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2284694-ae8a-4afc-8858-33359c99a075",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from math import ceil\n",
    "\n",
    "\n",
    "def get_split(samples, val_split: float = 0.15, test_split: float = 0.15, seed=1):\n",
    "    rng = np.random.default_rng(seed=seed)\n",
    "\n",
    "    samples = np.array(samples)\n",
    "    rng.shuffle(samples)\n",
    "    n_samples = len(samples)\n",
    "\n",
    "    n_samples_val = ceil(val_split * n_samples)\n",
    "    n_samples_test = ceil(test_split * n_samples)\n",
    "    n_samples_train = n_samples - n_samples_val - n_samples_test\n",
    "\n",
    "    return {\n",
    "        'train': samples[:n_samples_train],\n",
    "        'val': samples[n_samples_train:(n_samples_train + n_samples_val)],\n",
    "        'test': samples[(n_samples_train + n_samples_val):]\n",
    "    }\n",
    "\n",
    "\n",
    "def subset(splits, frac):\n",
    "    assert 0. < frac <= 1.\n",
    "    if frac == 1.:\n",
    "        return splits\n",
    "    else:\n",
    "        return splits[:ceil(frac * len(splits))]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3650eeb4-9d73-45ab-b6f3-6e91f4a2d140",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# subsample_fracs: 0.15, 0.3, 0.5, 0.7, 1.\n",
    "SUBSAMPLE_FRAC = 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f14ea53f-9cbb-4359-aaa8-7acf51217577",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tech_sample_splits = get_split(obs.tech_sample.unique().tolist())\n",
    "# tech_samples are already shuffled by get_split -> taking the leading part is enough to\n",
    "# subsample donors. Only the training set is subsampled, val + test always stay complete.\n",
    "selected_tech_samples = {\n",
    "    'train': subset(tech_sample_splits['train'], SUBSAMPLE_FRAC),\n",
    "    'val': tech_sample_splits['val'],\n",
    "    'test': tech_sample_splits['test'],\n",
    "}\n",
    "# row indices of obs that belong to each split\n",
    "splits = {\n",
    "    name: obs[obs.tech_sample.isin(tech_samples)].index.to_numpy()\n",
    "    for name, tech_samples in selected_tech_samples.items()\n",
    "}\n",
    "\n",
    "splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b59614c4-27a8-464e-809c-75860636b368",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# no cell may show up in more than one split -> check every distinct pair of splits once\n",
    "for first, second in [('train', 'val'), ('train', 'test'), ('val', 'test')]:\n",
    "    n_shared = len(np.intersect1d(splits[first], splits[second]))\n",
    "    assert n_shared == 0, f'{first} and {second} split share {n_shared} cells'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3bd131f-32fb-4f7e-b8e0-516e0caf73a3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# number of cells per split\n",
    "for name in ('train', 'val', 'test'):\n",
    "    n_cells = len(obs.loc[splits[name], :])\n",
    "    print(f'{name}: {n_cells:,} cells')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ddcd10e-d67f-416d-ba65-9f51124e8459",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# number of distinct cell types per split\n",
    "for name in ('train', 'val', 'test'):\n",
    "    n_cell_types = len(np.unique(obs.loc[splits[name], 'cell_type']))\n",
    "    print(f'{name}: {n_cell_types} celltypes')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5d72e17-7a52-4a93-9c7b-aaaeb6653779",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# number of distinct tech_sample values (dataset_id + donor_id) per split\n",
    "for name in ('train', 'val', 'test'):\n",
    "    n_tech_samples = len(np.unique(obs.loc[splits[name], 'tech_sample']))\n",
    "    print(f'{name}: {n_tech_samples} donors')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af5351c-bdf9-4f4f-8323-8975ac0945fc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(seed=1)\n",
    "\n",
    "# shuffle the row indices of every split; the fixed order train -> val -> test keeps\n",
    "# the sequence of draws from the generator (and thus the result) reproducible\n",
    "for name in ('train', 'val', 'test'):\n",
    "    splits[name] = rng.permutation(splits[name])\n",
    "\n",
    "splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee8aa4f4-a72a-43b0-9872-dff02654894d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "69bc24e5-c047-46d2-94f2-ea901cc68438",
   "metadata": {},
   "source": [
    "### Save data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "885bce5a-4389-411c-869c-c76ef126fdeb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# number of rows per chunk of the saved X array\n",
    "CHUNK_SIZE = 16384\n",
    "\n",
    "# subsampled stores get their own directory, suffixed with the subsample percentage\n",
    "SAVE_PATH = '/mnt/dssfs02/cxg_census/data_2023_05_15'\n",
    "if SUBSAMPLE_FRAC < 1.:\n",
    "    SAVE_PATH += f'_subsample_{round(SUBSAMPLE_FRAC * 100)}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dfe0471-2295-4e31-b4bb-bfa996f43d79",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# only save train data for subset stores\n",
    "# val + test can be copied later from the non-subset store\n",
    "splits_to_save = ['train'] if SUBSAMPLE_FRAC < 1. else ['train', 'val', 'test']\n",
    "\n",
    "\n",
    "for split in splits:\n",
    "    if split not in splits_to_save:\n",
    "        continue\n",
    "\n",
    "    idxs = splits[split]\n",
    "    # out-of-order indexing is on purpose here as we want to shuffle the data to break up data sets\n",
    "    X_split = X[idxs, :].rechunk((CHUNK_SIZE, -1))\n",
    "    obs_split = obs.loc[idxs, :]\n",
    "\n",
    "    save_dir = join(SAVE_PATH, split)\n",
    "    os.makedirs(save_dir)\n",
    "\n",
    "    # var is identical for every split and is stored next to each of them\n",
    "    parquet_kwargs = dict(engine='pyarrow', compression='snappy', index=None)\n",
    "    var.to_parquet(path=join(save_dir, 'var.parquet'), **parquet_kwargs)\n",
    "    obs_split.to_parquet(path=join(save_dir, 'obs.parquet'), **parquet_kwargs)\n",
    "\n",
    "    # convert every block to a dense array before it is written to the zarr store\n",
    "    X_dense = X_split.map_blocks(lambda block: block.toarray(), dtype='f4')\n",
    "    da.to_zarr(\n",
    "        X_dense,\n",
    "        join(save_dir, 'zarr'),\n",
    "        component='X',\n",
    "        compute=True,\n",
    "        compressor='default',\n",
    "        order='C'\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1455be19-10cc-4e0f-9133-9e9fe9814de6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
